{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "149ec08a",
   "metadata": {},
   "source": [
    "# Lesson 4 - Quantization\n",
    "\n",
    "\n",
    "In this lesson, we'll discuss the concept of \"quantization\". This technique helps reduce the memory overhead of a model and enables running inference with larger LLMs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7fe394c",
   "metadata": {},
   "source": [
    "### Import required packages and load the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d9ea926-bb41-42f4-a19f-bf6b1f0a36ee",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers.models.gpt2.modeling_gpt2 import GPT2Model\n",
    "from utils import generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf64d9c8-a1a5-4dab-8ed5-bff0c8a0a463",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "model_name = \"./models/gpt2\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84649d11-146e-4f52-b999-ed70530cfefe",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Define PAD Token = EOS Token = 50256\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "# pad on the left so we can append new tokens on the right\n",
    "tokenizer.padding_side = \"left\"\n",
    "tokenizer.truncation_side = \"left\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f546258",
   "metadata": {},
   "source": [
    "### Define a Float 32 type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23901b5f-5e1b-4784-b495-1bb1a52b3ca9",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# fix dtype post quantization to \"pretend\" to be fp32\n",
    "def get_float32_dtype(self):\n",
    "    return torch.float32\n",
    "GPT2Model.dtype = property(get_float32_dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5bca2ec-51f9-47d6-b44b-8e39d6032f6c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model.get_memory_footprint()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73cdde65",
   "metadata": {},
   "source": [
    "### Define a quantization function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5c14e2f-97fe-4238-bc8e-d7c248304e0a",
   "metadata": {
    "height": 335
   },
   "outputs": [],
   "source": [
    "def quantize(t):\n",
    "    # obtain range of values in the tensor to map between 0 and 255\n",
    "    min_val, max_val = t.min(), t.max()\n",
    "\n",
    "    # determine the \"zero-point\", or value in the tensor to map to 0\n",
    "    scale = (max_val - min_val) / 255\n",
    "    zero_point = min_val\n",
    "\n",
    "    # quantize and clamp to ensure we're in [0, 255]\n",
    "    t_quant = (t - zero_point) / scale\n",
    "    t_quant = torch.clamp(t_quant, min=0, max=255)\n",
    "\n",
    "    # keep track of scale and zero_point for reversing quantization\n",
    "    state = (scale, zero_point)\n",
    "\n",
    "    # cast to uint8 and return\n",
    "    t_quant = t_quant.type(torch.uint8)\n",
    "    return t_quant, state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68958427-1202-4010-adb7-04653009f0a0",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "t = model.transformer.h[0].attn.c_attn.weight.data\n",
    "print(t, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d31f3c-7cac-49d5-8d01-df6e26ac0af8",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "t_q, state = quantize(t)\n",
    "print(t_q, t_q.min(), t_q.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee293fe0",
   "metadata": {},
   "source": [
    "### Define a dequantization function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c56b48f-30fa-43a3-a39d-931168fe5ca0",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "def dequantize(t, state):\n",
    "    scale, zero_point = state\n",
    "    return t.to(torch.float32) * scale + zero_point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f94fc6-3b3b-4763-9864-d0e63318d63b",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "t_rev = dequantize(t_q, state)\n",
    "print(t_rev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd63d230-7103-4c2c-bc3f-1d58797c97aa",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "torch.abs(t - t_rev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0cc07e7-06d9-411c-8432-85f336f5afd8",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "response_expected = generate(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    [(\"The quick brown fox jumped over the\", 10)]\n",
    ")[0]\n",
    "response_expected"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "490361da",
   "metadata": {},
   "source": [
    "### Let's apply the quantization technique to the entire model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e71577ea-e974-41d1-b1d5-6687a3e6d02e",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "def quantize_model(model):\n",
    "    states = {}\n",
    "    for name, param in model.named_parameters():\n",
    "        param.requires_grad = False\n",
    "        param.data, state = quantize(param.data)\n",
    "        states[name] = state\n",
    "    return model, states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb95a21a-061f-4add-99fc-45ae61aec5f3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "quant_model, states = quantize_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07748809-4f3c-4899-a752-a1a03cdcef34",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "quant_model.get_memory_footprint()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eed4949-0c4a-4ef6-9db6-4f16eb365827",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "def size_in_bytes(t):\n",
    "    return t.numel() * t.element_size()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf4e71a-7202-415c-909b-e99584731df7",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "sum([\n",
    "    size_in_bytes(v[0]) + size_in_bytes(v[1])\n",
    "    for v in states.values()\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7506d57b-65cd-4c13-ad55-621acc9829a5",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "def dequantize_model(model, states):\n",
    "    for name, param in model.named_parameters():\n",
    "        state = states[name]\n",
    "        param.data = dequantize(param.data, state)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9228aba0-4e19-45b2-88bf-6902aae129a4",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "dequant_model = dequantize_model(quant_model, states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e1872f9-9580-4040-a924-53268d76bf8a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "dequant_model.get_memory_footprint()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9360f264-b765-4160-8ca0-7b60423b1ba6",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "response_expected = generate(\n",
    "    dequant_model,\n",
    "    tokenizer,\n",
    "    [(\"The quick brown fox jumped over the\", 10)]\n",
    ")[0]\n",
    "response_expected"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
